import warnings
from typing import Optional

import numpy as np
import pandas as pd
from scipy.interpolate import interp1d
from scipy.optimize import curve_fit


def load_D_values(path: str = "data/D_values.csv"):
    """
    Read the fractal‑dimension anchor table from a CSV file.

    The file must provide an ``n`` column (context index) and a ``D`` column
    (fractal dimension); a ``sigma`` column with per-point uncertainties is
    optional.

    Returns:
      n_vals (ndarray), D_vals (ndarray), sigma_vals (ndarray, or None when
      the file has no ``sigma`` column)
    """
    table = pd.read_csv(path)
    # DataFrame.get yields None for a missing column instead of raising.
    sigma_column = table.get("sigma")
    sigma_vals = None if sigma_column is None else sigma_column.to_numpy()
    return table["n"].to_numpy(), table["D"].to_numpy(), sigma_vals


def logistic_model(n, Dmin, Dmax, k, n0):
    """
    Four‑parameter logistic model:
      D(n) = Dmin + (Dmax − Dmin) / (1 + exp[−k (n − n0)])

    Evaluated through the exact identity
      1 / (1 + exp(−z)) = (1 + tanh(z / 2)) / 2,   with z = k (n − n0),
    so that a large |z| (a steep trial ``k`` during curve fitting, or ``n``
    far from ``n0``) saturates cleanly at Dmin / Dmax instead of overflowing
    ``exp`` and emitting RuntimeWarnings.

    Args:
      n: context index or array of indices
      Dmin: value approached as k (n − n0) → −∞
      Dmax: value approached as k (n − n0) → +∞
      k: steepness of the transition (negative k reverses its direction)
      n0: midpoint, where D(n0) = (Dmin + Dmax) / 2

    Returns:
      D(n), broadcast over the inputs
    """
    z = k * (n - n0)
    # tanh is bounded in [-1, 1], so no intermediate value can overflow.
    return Dmin + (Dmax - Dmin) * 0.5 * (1.0 + np.tanh(0.5 * z))


def fit_fractal_curve(n_vals: np.ndarray,
                      D_vals: np.ndarray,
                      sigma_vals: Optional[np.ndarray] = None,
                      *,
                      p0: Optional[np.ndarray] = None,
                      maxfev: int = 10000):
    """
    Fit a logistic D(n) curve to data points.

    Args:
      n_vals: array of context indices
      D_vals: array of corresponding fractal dimensions
      sigma_vals: optional uncertainties for weighting; when given they are
        passed to ``curve_fit`` with ``absolute_sigma=True``
      p0: optional initial guess [Dmin, Dmax, k, n0]; defaults to
        [min(D_vals), max(D_vals), 1.0, mean(n_vals)]
      maxfev: maximum number of model evaluations allowed to ``curve_fit``

    Returns:
      popt: fitted parameters [Dmin, Dmax, k, n0], or the initial guess if
        the fit failed to converge
      pcov: covariance matrix of the fit (or None if the fit failed)

    Warns:
      RuntimeWarning: when ``curve_fit`` does not converge and the initial
        guess is returned in place of fitted parameters.
    """
    if p0 is None:
        # Default guess: asymptotes at the observed extremes, unit steepness,
        # midpoint at the mean context index.
        p0 = [np.min(D_vals), np.max(D_vals), 1.0, np.mean(n_vals)]
    try:
        popt, pcov = curve_fit(
            logistic_model,
            n_vals,
            D_vals,
            p0=p0,
            sigma=sigma_vals,
            absolute_sigma=(sigma_vals is not None),
            maxfev=maxfev,
        )
    except RuntimeError as err:
        # Keep the best-effort fallback, but tell the caller that the
        # returned parameters are only the starting guess, not a fit.
        warnings.warn(
            f"logistic fit did not converge ({err}); returning the initial guess",
            RuntimeWarning,
            stacklevel=2,
        )
        popt = np.array(p0, dtype=float)
        pcov = None
    return popt, pcov


def logistic_D_function(n: np.ndarray, popt: np.ndarray) -> np.ndarray:
    """
    Evaluate the logistic curve D(n) for a given parameter vector.

    Args:
      n: context indices at which to evaluate the curve
      popt: exactly four parameters in the order [Dmin, Dmax, k, n0],
        the order produced by ``fit_fractal_curve``

    Returns:
      The dimension values D(n) computed by ``logistic_model``
    """
    lower, upper, steepness, midpoint = popt
    return logistic_model(n, Dmin=lower, Dmax=upper, k=steepness, n0=midpoint)


def n_from_D(D_vals: np.ndarray,
             n_vals: np.ndarray,
             D_target: float) -> float:
    """
    Map a target fractal dimension back to a (possibly fractional) context index.

    Builds a piecewise-linear interpolant of n as a function of D through the
    anchor points and evaluates it at ``D_target``. Targets outside the range
    of ``D_vals`` are linearly extrapolated rather than rejected. The
    inversion is only meaningful when D is monotonic in n; this is not
    checked here.

    Args:
      D_vals: fractal dimensions of the anchor points (need not be sorted)
      n_vals: context indices paired element-wise with ``D_vals``
      D_target: dimension whose context index is wanted

    Returns:
      The interpolated (or extrapolated) context index as a Python float
    """
    inverse_curve = interp1d(
        x=D_vals,
        y=n_vals,
        kind="linear",
        fill_value="extrapolate",
    )
    n_estimate = inverse_curve(D_target)
    return float(n_estimate)


def pivot_function(D_vals: np.ndarray,
                   a: float,
                   b: float) -> np.ndarray:
    """
    Evaluate the affine pivot weighting g(D) = a * D + b.

    Args:
      D_vals: fractal dimensions (an array, or a scalar)
      a: slope of the affine pivot
      b: offset of the affine pivot

    Returns:
      Pivot weighting factors g(D), element-wise for array input
    """
    slope_term = a * D_vals
    weights = slope_term + b
    return weights
